# -*- coding: utf-8 -*-
"""
Created on Tue Jun 16 15:58:07 2020

@author: Colleen
"""
import pandas as pd
import cv2
import numpy as np
from sklearn import tree
from sklearn.model_selection import train_test_split
from sklearn.metrics import classification_report 
import os
os.environ["PATH"] += os.pathsep + 'F:/graphviz-2.38/release/bin'
from sklearn.externals import joblib

def binaryzation(img):
    """Return an inverted binary mask of ``img``.

    The input is cast to ``uint8``; every element whose cast value is at
    most 50 maps to 1 and every element above 50 maps to 0. This is the
    mapping ``cv2.threshold(src, 50, 1, cv2.THRESH_BINARY_INV)`` applies to
    8-bit data, written as a NumPy comparison so that it accepts any
    array-like of any shape (including an empty one).

    Args:
        img: Array-like of pixel intensities.

    Returns:
        A new ``numpy.uint8`` array with the same shape as ``img`` that
        contains only 0 and 1. The input itself is not modified.
    """
    # Cast first so the comparison sees exactly the 8-bit values.
    cv_img = np.asarray(img).astype(np.uint8)
    # Inverted threshold: dark pixels (<= 50) become 1, bright ones 0.
    return (cv_img <= 50).astype(np.uint8)


def binaryzation_features(trainset):
    """Binarize a batch of flattened 28x28 images.

    Each item of ``trainset`` is reshaped to 28x28, cast to ``uint8``,
    passed through ``binaryzation`` and flattened again, so the result has
    one row of 28 * 28 = 784 binary features per input image.

    Args:
        trainset: Iterable of images, each holding exactly 784 values.

    Returns:
        ``numpy.ndarray`` of shape ``(n_images, 784)``; an empty
        ``trainset`` yields shape ``(0, 784)``.

    Raises:
        ValueError: If an item cannot be reshaped to 28x28.
    """
    side = 28  # images are side x side pixels

    features = [
        binaryzation(np.reshape(img, (side, side)).astype(np.uint8))
        for img in trainset
    ]

    # The row width comes from the image size. The previous code reshaped
    # to (-1, feature_len) with feature_len = 13; 784 is not a multiple of
    # 13, so that either raised or produced the wrong number of rows.
    return np.reshape(np.array(features), (len(features), side * side))

#class_num = 10  # the MNIST dataset has 10 labels: "0,1,2,3,4,5,6,7,8,9"
# NOTE(review): the original comment on this constant said that each MNIST
# image has 28*28=784 features (pixels), which does not match the value 13 --
# confirm what width this is meant to describe.
feature_len = 13
epsilon = 0.001  # threshold setting

if __name__ == '__main__':
    # Script flow: load a CSV, fit a decision tree, export the tree as a
    # PDF, persist and reload the model, then report predictions.

    # Read the CSV data; the first row is the header.
    #raw_data = pd.read_csv('train.csv', header=0)
    raw_data = pd.read_csv('test_data.csv', header=0)
    data = raw_data.values

    # Column 0 holds the label, the remaining columns hold the features.
    #imgs = data[::, 1::]
    #features = binaryzation_features(imgs)  # binarize the images (important, otherwise accuracy is very low)
    labels = data[::, 0]
    features = data[::, 1::]

    # Hold out a random 20% of the samples as the test set and train on
    # the rest, so the score is measured on data the tree has not seen.
    train_features, test_features, train_labels, test_labels = train_test_split(
        features, labels, test_size=0.2, random_state=0)
    clf = tree.DecisionTreeClassifier(criterion='entropy', random_state=0)
    clf = clf.fit(train_features, train_labels)
    print("clf:" + str(clf))

    # Render the fitted tree and save it as a PDF file.
    import pydotplus

    dot_data = tree.export_graphviz(clf, out_file=None,
                                    filled=True, rounded=True,
                                    special_characters=True)
    graph = pydotplus.graph_from_dot_data(dot_data)
    graph.write_pdf("mnist.pdf")

    # Save the model, then load it back from disk.
    # NOTE(review): joblib is imported from sklearn.externals at the top of
    # this file; newer scikit-learn releases no longer ship that module --
    # consider switching to the standalone `joblib` package.
    joblib.dump(clf, "train_model.m")
    clf2 = joblib.load("train_model.m")

    # Score the hold-out set with the reloaded model, so the save/load
    # round trip is what actually gets evaluated (the previous code
    # discarded this prediction and re-predicted with the in-memory model).
    y_predict = clf2.predict(test_features)
    score = classification_report(test_labels, y_predict)
    print(score)

    # Spot-check individual samples. The trailing numbers are presumably
    # the expected labels -- TODO confirm.
    test1 = [0, 0, -1, 1, 1, 1, -2, 0, 0, 2, -1, 0]  # 1
    test2 = [0, 0, 1, 2, 2, 2, 0, -2, -2, 0, 1, 0]  # 2
    test3 = [0, 0, 1, 1, 1, 2, 0, 0, 0, 2, -1, 0]  # 3
    # Predict all samples in one batch; previously the result for test1 was
    # overwritten before printing and test2 was never used.
    y_predict1 = clf.predict([test1, test2, test3])
    print(y_predict1)
    

